import contextlib
import functools
import itertools
import logging
import sys
import threading


class _State(threading.local):
    """Per-thread holder for the stack of currently active contexts.

    ``contexts`` is a tuple of ``(class, arg)`` pairs; entering a context
    appends to the end, so the innermost context is last.  Because this is a
    ``threading.local``, every thread starts with its own empty tuple.
    """

    def __init__(self):
        threading.local.__init__(self)
        self.contexts = tuple()


# Shared module-level handle; each thread sees its own ``contexts`` value.
_state = _State()


class StackContext(object):
    """Context manager that is replayed around callbacks produced by ``wrap``.

    ``context_factory`` is a zero-argument callable returning a context
    manager.  Entering a ``StackContext`` pushes the pair
    ``(StackContext, context_factory)`` onto the per-thread context stack and
    then enters a freshly created manager; ``wrap`` later rebuilds a
    ``StackContext`` from that recorded pair to re-enter it around deferred
    callbacks.
    """

    def __init__(self, context_factory):
        self.context_factory = context_factory

    def __enter__(self):
        saved = _state.contexts
        self.old_contexts = saved
        # Each stack entry is a (class, arg) pair so it can be rebuilt later.
        entry = (StackContext, self.context_factory)
        _state.contexts = saved + (entry,)
        try:
            manager = self.context_factory()
            self.context = manager
            manager.__enter__()
        except Exception:
            # The with-statement will not call __exit__ when __enter__
            # fails, so undo the push here before propagating.
            _state.contexts = saved
            raise

    def __exit__(self, type, value, traceback):
        # Delegate to the inner manager (its return value decides whether an
        # exception is suppressed) and always pop our entry afterwards.
        try:
            outcome = self.context.__exit__(type, value, traceback)
        finally:
            _state.contexts = self.old_contexts
        return outcome


class ExceptionStackContext(object):
    """Context whose exception handler is re-installed around wrapped callbacks.

    ``exception_handler`` is called as ``handler(type, value, traceback)``
    when an exception escapes the ``with`` body; its return value is returned
    from ``__exit__``, so a truthy result suppresses the exception.
    """

    def __init__(self, exception_handler):
        self.exception_handler = exception_handler

    def __enter__(self):
        previous = _state.contexts
        self.old_contexts = previous
        entry = (ExceptionStackContext, self.exception_handler)
        _state.contexts = previous + (entry,)

    def __exit__(self, type, value, traceback):
        try:
            if type is None:
                # Clean exit: nothing to handle.
                return None
            return self.exception_handler(type, value, traceback)
        finally:
            # Pop our entry whether or not the handler itself raised.
            _state.contexts = self.old_contexts


class NullContext(object):
    """Temporarily clear the per-thread context stack.

    Inside the ``with`` block ``_state.contexts`` is the empty tuple; the
    previous stack is put back on exit.  Exceptions are never suppressed.
    """

    def __enter__(self):
        # Right-hand side is evaluated first: save the old stack, then empty it.
        self.old_contexts, _state.contexts = _state.contexts, ()

    def __exit__(self, type, value, traceback):
        _state.contexts = self.old_contexts


class _StackContextWrapper(functools.partial):
    """Marker subclass of ``functools.partial`` for callbacks built by ``wrap``.

    ``wrap`` checks for this exact class so an already-wrapped callback is
    returned unchanged instead of being wrapped a second time.
    """


def wrap(fn):
    """Return a callable that re-establishes the current contexts around *fn*.

    The per-thread context stack active at wrap time (``_state.contexts``)
    is captured.  When the returned callable is invoked, captured contexts
    that are not already active are rebuilt from their recorded
    ``(class, arg)`` pairs and entered before *fn* runs with the call's
    arguments.

    ``None`` and callables already produced by ``wrap`` are returned
    unchanged, so wrapping twice is a no-op.

    Note: *fn*'s return value is discarded; the wrapper always returns None.
    """
    if fn is None or fn.__class__ is _StackContextWrapper:
        return fn

    # The result is a functools.partial subclass (so the check above can
    # recognise it) rather than a plain closure decorated with
    # functools.wraps; the original author noted functools.wraps did not
    # appear to work on partial objects.
    def wrapped(callback, contexts, *args, **kwargs):
        current = _state.contexts
        if contexts is current or not contexts:
            # Same stack already active, or nothing was captured.
            callback(*args, **kwargs)
            return
        if not current:
            new_contexts = [cls(arg) for (cls, arg) in contexts]
        elif (len(current) > len(contexts) or
              any(a[1] is not b[1] for a, b in zip(current, contexts))):
            # The active stack is not a prefix of the captured one (entries
            # compared by identity of their arg): clear it with NullContext
            # and rebuild the whole captured stack.
            new_contexts = ([NullContext()] +
                            [cls(arg) for (cls, arg) in contexts])
        else:
            # The active stack is a prefix of the captured one; only the
            # missing tail needs to be entered.
            new_contexts = [cls(arg)
                            for (cls, arg) in contexts[len(current):]]
        if len(new_contexts) > 1:
            with _nested(*new_contexts):
                callback(*args, **kwargs)
        elif new_contexts:
            with new_contexts[0]:
                callback(*args, **kwargs)
        else:
            callback(*args, **kwargs)
    return _StackContextWrapper(wrapped, fn, _state.contexts)


@contextlib.contextmanager
def _nested(*managers):
    """Enter every manager in *managers* as a single context manager.

    Behaves like the literal ``with m0 as v0: with m1 as v1: ...``:
    managers are entered left to right and exited right to left, and the
    ``with`` target is the list of their ``__enter__`` results.  Any
    exception (including non-``Exception`` ``BaseException`` subclasses)
    raised by the body, by an ``__enter__`` or by an ``__exit__`` is passed
    to the ``__exit__`` of every manager entered outside it, and a manager
    whose ``__exit__`` returns true suppresses it for the managers outside
    it.

    Needed because the list of managers is only known at run time, so a
    literal nested ``with`` statement cannot be written.
    """
    if not managers:
        yield []
        return
    # Recurse instead of hand-driving __enter__/__exit__: the interpreter's
    # own with-statement machinery gets ordering, suppression and re-raising
    # (with the original traceback) right, and no sys.exc_info() juggling or
    # version-specific raise syntax is needed.
    with managers[0] as first:
        with _nested(*managers[1:]) as rest:
            yield [first] + rest
